
include("../src/qmcmary.jl")
using ..qmcmary

using Test
using ReverseDiff: jacobian, jacobian!
using LinearAlgebra

# Restrict BLAS to a single thread for everything run by this script.
BLAS.set_num_threads(1)

"""
    sgn_scratch(ss::ScrollSVD)

Recompute a log-magnitude from the factorizations stored in `ss`.

The position of the first `true` entry of `ss.L` selects which stored factorizations
`ss.F[...]` provide the "left" triple `(VL, DL, UL)` and the "right" triple
`(UR, DR, VR)`; a missing side is replaced by identity factors. Each diagonal factor
is split into a part holding the entries `> 1` and a part holding the entries `<= 1`,
the matrix

    M = inv(DRB)*adjoint(UL*UR)*inv(DLB) + DRS*(VR*VL)*DLS

is formed and decomposed with an SVD, and the function returns

    log|det(Fm.Vt*UL)*det(UR*Fm.U)| + Σ log(Fm.S) + Σ log(DLB) + Σ log(DRB)

Despite the name, the return value is this log-magnitude, not a sign/phase.
NOTE(review): presumably this equals `log|det(I + UR*DR*VR*VL*DL*UL)|` when the `U`
factors are unitary — confirm against the package's conventions.
"""
function sgn_scratch(ss::ScrollSVD{T}) where T
    n = size(ss.B[end])[1]
    eye = Diagonal(ones(n))
    # Unpack a stored SVD factorization as (U, Diagonal(S), Vt).
    factors(F) = (F.U, Diagonal(F.S), F.Vt)

    ptr = findfirst(ss.L)
    if isnothing(ptr)
        # No flagged entry: identity on the left, last factorization on the right.
        VL, DL, UL = eye, eye, eye
        UR, DR, VR = factors(ss.F[end])
    elseif ptr == 1
        # First entry flagged: first factorization on the left, identity on the right.
        # The original author noted that this arrangement can suddenly lose numerical
        # stability and suggested the mirrored alternative (identity on the left,
        # `ss.F[1]` on the right); that alternative was left disabled.
        VL, DL, UL = factors(ss.F[1])
        UR, DR, VR = eye, eye, eye
    else
        VL, DL, UL = factors(ss.F[ptr])
        UR, DR, VR = factors(ss.F[ptr-1])
    end

    # Split a diagonal factor D into a "big" part (entries > 1, others set to 1) and
    # a "small" part (entries <= 1, others set to 1), so that D == big * small.
    bigpart(D) = Diagonal(Float64[D[i, i] > 1.0 ? D[i, i] : 1.0 for i in Base.OneTo(n)])
    smallpart(D) = Diagonal(Float64[D[i, i] > 1.0 ? 1.0 : D[i, i] for i in Base.OneTo(n)])
    DLB, DLS = bigpart(DL), smallpart(DL)
    DRB, DRS = bigpart(DR), smallpart(DR)

    # Only the small parts and the inverses of the big parts enter M.
    M = inv(DRB)*adjoint(UL*UR)*inv(DLB) + DRS*(VR*VL)*DLS
    Fm = svd(M, alg=LinearAlgebra.QRIteration())

    # Log-magnitude of the determinant of the outer factors ...
    logabs = log(abs(det(Fm.Vt*UL)*det(UR*Fm.U)))
    # ... plus the logs of the singular values and of the big diagonal parts,
    # accumulated one by one in this fixed order.
    for vals in (Fm.S, diag(DLB), diag(DRB)), s in vals
        logabs += log(s)
    end
    return logabs
end


"""
    pbpU_example()

Print a side-by-side comparison for the first B matrix on a 3x3 kagome lattice with a
random auxiliary-field configuration: the finite-difference quotient obtained by
shifting `Ui[1]` by `1e-4`, followed by the output of `pbpU_calc_xi` for the same site
(presumably the analytic derivative — confirm against the package). Nothing is
asserted; the output is meant for visual inspection.
"""
function pbpU_example()
    L = 3
    Nx = 3*L^2
    Nt, Ng = 50, 5
    ni = 1          # site whose interaction strength is perturbed
    dU = 0.0001     # finite-difference step
    nx, ny, nz = zeros(Nx), zeros(Nx), ones(Nx)
    # Two copies of the kagome hopping matrix, one per block.
    hk = kron([1 0; 0 1], lattice_kagome(ComplexF64, L, -1.0+0.0im))
    Ui = 1*ones(Nx)
    # Random ±1 configuration, one row per time slice.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for bi in Base.OneTo(Nt)
        hscfg[bi, :] .= 2(round.(rand(Nx)) .- 0.5)
    end

    # Reference B matrices and the quantity to be compared, at the unperturbed Ui.
    sp = default_splitting(Nt, hk, Ui; Z2=false)
    _, _, allbmats1, _ = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)
    tapemap = pbpU_calc_tapemap(sp)
    adf_r, _ = pbpU_calc_xi(sp, hscfg[:, ni], ni; tapemap=tapemap)

    # Perturb the Ui entry at site ni and rebuild the B matrices.
    Ui[ni] += dU
    sp = default_splitting(Nt, hk, Ui; Z2=false)
    _, _, allbmats2, _ = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)

    # Finite-difference quotient of the real part, rows 1 and Nx+1.
    rdiff = real(allbmats2[1]) - real(allbmats1[1])
    for row in (1, Nx+1)
        println(rdiff[row, :]/dU)
    end
    # Corresponding slices returned by pbpU_calc_xi.
    for k in (1, 2)
        println(adf_r[1, k, :])
    end
end


"""
    pbpUabnm_example()

Print a side-by-side comparison for the first B matrix on a 3x3 kagome lattice with a
random auxiliary-field configuration and the direction vectors `nx = 0`, `ny = 1`,
`nz = 0`: the finite-difference quotient obtained by shifting `Ui[1]` by `1e-4`,
followed by the output of `pbpUabnm_calc_xi` for the same site (presumably the analytic
derivative — confirm against the package). Nothing is asserted; the output is meant
for visual inspection.
"""
function pbpUabnm_example()
    L = 3
    Nx = 3*L^2
    Nt, Ng = 50, 5
    ni = 1          # site whose interaction strength is perturbed
    dU = 0.0001     # finite-difference step
    nx, ny, nz = zeros(Nx), ones(Nx), zeros(Nx)
    # Two copies of the kagome hopping matrix, one per block.
    hk = kron([1 0; 0 1], lattice_kagome(ComplexF64, L, -1.0+0.0im))
    Ui = 1*ones(Nx)
    # Random ±1 configuration, one row per time slice.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for bi in Base.OneTo(Nt)
        hscfg[bi, :] .= 2(round.(rand(Nx)) .- 0.5)
    end

    # Reference B matrices and the quantity to be compared, at the unperturbed Ui.
    sp = default_splitting(Nt, hk, Ui; Z2=false)
    _, _, allbmats1, _ = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)
    tapemap = pbpUabnm_calc_tapemap(sp, nx, ny, nz)
    adf_r, _ = pbpUabnm_calc_xi(sp, hscfg[:, ni], ni, nx, ny, nz; tapemap=tapemap)

    # Perturb the Ui entry at site ni and rebuild the B matrices.
    Ui[ni] += dU
    sp = default_splitting(Nt, hk, Ui; Z2=false)
    _, _, allbmats2, _ = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)

    # Finite-difference quotient of the real part, rows 1 and Nx+1.
    rdiff = real(allbmats2[1]) - real(allbmats1[1])
    for row in (1, Nx+1)
        println(rdiff[row, :]/dU)
    end
    # Corresponding slices returned by pbpUabnm_calc_xi.
    for k in (1, 2)
        println(adf_r[1, k, :])
    end
end


"""
    Ubar_example()

Finite-difference check of `meas_gradU`.

For a fixed random auxiliary-field configuration, compute the log value returned by
`sgn_scratch` together with the gradient `Ubar` from `meas_gradU`. Then shift each
`Ui[xi]` by `1e-3` in turn, recompute the log value from scratch, and test that the
negated change agrees with `1e-3 * Ubar[xi]` (`rtol = 1e-2`, `atol = 2e-4`).
"""
function Ubar_example()
    L = 3
    Nx = 3*L^2
    Nt, Ng = 50, 5
    h = 0.001       # finite-difference step
    nx, ny, nz = zeros(Nx), zeros(Nx), ones(Nx)

    # Kagome hopping (two blocks) plus a random Hermitian perturbation.
    hk = kron([1 0; 0 1], lattice_kagome(ComplexF64, L, -1.0+0.0im))
    pet = rand(ComplexF64, 2*Nx, 2*Nx)
    hk += pet + adjoint(pet)

    Ui = 1*ones(Nx)
    # Random ±1 configuration, one row per time slice.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for bi in Base.OneTo(Nt)
        hscfg[bi, :] .= 2(round.(rand(Nx)) .- 0.5)
    end

    # Reference log value and gradient at the unperturbed Ui.
    sp = default_splitting(Nt, hk, Ui; Z2=false)
    _, _, allbmats, ss = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)
    logS1 = sgn_scratch(ss)
    tapemap = pbpU_calc_tapemap(sp)
    Ubar = meas_gradU(ss, allbmats, sp, hscfg; tapemap=tapemap)

    # Log value recomputed from scratch with Ui[site] shifted by h.
    function shifted_logS(site)
        Ui2 = copy(Ui)
        Ui2[site] += h
        sp2 = default_splitting(Nt, hk, Ui2; Z2=false)
        _, _, _, ss2 = initialize_SS(Nt, Ng, sp2, hscfg, nx, ny, nz)
        return sgn_scratch(ss2)
    end

    admat = zeros(Float64, Nx, 1)   # predicted change: h * gradient
    ndmat = zeros(Float64, Nx, 1)   # numerical change
    for xi in Base.OneTo(Nx)
        # d log S = - d logabs(w), hence the minus sign on the numerical difference.
        ndmat[xi, 1] = logS1 - shifted_logS(xi)
        admat[xi, 1] = h*Ubar[xi]
    end
    println(admat, ndmat, admat-ndmat)
    @testset "-∂log(s)/∂x" begin
        @test all(isapprox.(ndmat, admat, rtol=1e-2, atol=2e-4))
    end
end


"""
    Ubar_Quad1()

Finite-difference check of `meas_gradU_Quad1` for a four-valued auxiliary field.

For a fixed random configuration with values in `{-2, -1, 1, 2}` and random `phis`,
compute the log value from `sgn_scratch` and the gradients returned by
`meas_gradU_Quad1` with `ignorew=true` and `ignorew=false`. Then shift each `Ui[xi]`
by `1e-3`, recompute the log value from scratch, and test two things
(`rtol = 1e-2`, `atol = 2e-4`):

* the negated change of the bare log value against `1e-3 * Ubar[xi]`;
* the negated change of the log value including the field-weight terms against
  `1e-3 * Ubarw[xi]`.

For the first site, some debugging output comparing the change of the first B matrix
with the result of `pbpU_calc_xi_Quad1` is printed between `---` markers.
"""
function Ubar_Quad1()
    L = 3
    Nx = 3*L^2
    Nt, Ng = 50, 5
    h = 0.001       # finite-difference step
    ni = 1          # site for which debugging output is printed
    phis = rand(Nx)

    # Kagome hopping (two blocks) plus a random Hermitian perturbation.
    hk = kron([1 0; 0 1], lattice_kagome(ComplexF64, L, -1.0+0.0im))
    pet = rand(ComplexF64, 2*Nx, 2*Nx)
    hk += pet + adjoint(pet)

    Ui = 1*ones(Nx)
    # Random configuration: draw from {-2, -1, 0, 1}, then move 0 and 1 up by one so
    # the values are {-2, -1, 1, 2}.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for bi in Base.OneTo(Nt)
        draw = floor.(4rand(Nx)) .- 2
        hscfg[bi, :] .= map(x -> x >= 0 ? x+1 : x, draw)
    end

    sp = default_splitting(Nt, hk, Ui; Z2=false)
    _, _, allbmats0, ss = initialize_SS_Quad1(Nt, Ng, sp, hscfg, phis)
    logS1 = sgn_scratch(ss)

    # Per-field-value constants entering the weight.
    # NOTE(review): these look like the four-point values ±sqrt(2(3±√6)) and 1∓√6/3 —
    # confirm against the package.
    ηdict = Dict(
        -2 => -3.301360247771569,
        -1 => -1.049295246550581,
        1 => 1.049295246550581,
        2 => 3.301360247771569
    )
    γdict = Dict(
        -2 => 0.18350341907227408,
        -1 => 1.8164965809277258,
        1 => 1.8164965809277258,
        2 => 0.18350341907227408
    )

    # Add the field-weight terms for interaction strengths U to a log value; the
    # terms are accumulated one at a time, time slice by time slice.
    function withweight(logS, U)
        acc = logS
        for bi in Base.OneTo(Nt), site in Base.OneTo(Nx)
            s = hscfg[bi, site]
            acc += sqrt(sp.dt*U[site]/2)*ηdict[s]*phis[site]
            acc += log(γdict[s])
        end
        return acc
    end

    # Reference value including the weight.
    logS1w = withweight(logS1, Ui)
    tapemap = pbpU_calc_tapemap_Quad1(sp, phis)
    Ubar = meas_gradU_Quad1(ss, allbmats0, sp, hscfg, phis; tapemap=tapemap, ignorew=true)
    Ubarw = meas_gradU_Quad1(ss, allbmats0, sp, hscfg, phis; tapemap=tapemap, ignorew=false)

    admat = zeros(Float64, Nx, 1)    # predicted change, weight ignored
    ndmat = zeros(Float64, Nx, 1)    # numerical change, weight ignored
    admatw = zeros(Float64, Nx, 1)   # predicted change, weight included
    ndmatw = zeros(Float64, Nx, 1)   # numerical change, weight included
    rows = [1, 1+Nx]
    for xi in Base.OneTo(Nx)
        Ui2 = copy(Ui)
        Ui2[xi] += h
        sp2 = default_splitting(Nt, hk, Ui2; Z2=false)
        _, _, allbmats, ss2 = initialize_SS_Quad1(Nt, Ng, sp2, hscfg, phis)
        if xi == ni
            # Debugging output: change of the first B matrix next to the values
            # returned by pbpU_calc_xi_Quad1.
            println("---")
            adr, adi = pbpU_calc_xi_Quad1(sp, hscfg[:, ni], ni, phis; tapemap=tapemap)
            println(allbmats[1][rows, :] - allbmats0[1][rows, :])
            println(adr[1, :, :])
            println(adi[1, :, :])
            println("---")
        end
        logS2 = sgn_scratch(ss2)
        logS2w = withweight(logS2, Ui2)
        # d log S = - d logabs(w), hence the minus sign on the numerical differences.
        ndmat[xi, 1] = logS1 - logS2
        admat[xi, 1] = h*Ubar[xi]
        ndmatw[xi, 1] = logS1w - logS2w
        admatw[xi, 1] = h*Ubarw[xi]
    end
    println(admat, ndmat, admat-ndmat)
    println(admatw, ndmatw, admatw-ndmatw)
    @testset "-∂log(s)/∂x" begin
        @test all(isapprox.(ndmat, admat, rtol=1e-2, atol=2e-4))
    end
    @testset "-∂wgt/∂x -∂log(s)/∂x" begin
        @test all(isapprox.(ndmatw, admatw, rtol=1e-2, atol=2e-4))
    end
end

# Script driver: only `pbpUabnm_example` is currently enabled. Uncomment the other
# calls to run the finite-difference gradient checks as well.
#Ubar_example()
pbpUabnm_example()
#Ubar_Quad1()
